{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bayesian network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the pgmpy building blocks used by this notebook.\n",
    "from pgmpy.factors.discrete import TabularCPD\n",
    "\n",
    "# pgmpy has renamed its Bayesian-network class across releases\n",
    "# (BayesianModel -> BayesianNetwork -> DiscreteBayesianNetwork).\n",
    "# Try the newest name first and fall back to the older ones so the notebook\n",
    "# imports cleanly on both old and new installations. The alias keeps the\n",
    "# name `BayesianModel` bound, exactly as the original import did.\n",
    "try:\n",
    "    from pgmpy.models import DiscreteBayesianNetwork as BayesianModel\n",
    "except ImportError:\n",
    "    try:\n",
    "        from pgmpy.models import BayesianNetwork as BayesianModel\n",
    "    except ImportError:\n",
    "        from pgmpy.models import BayesianModel\n",
    "\n",
    "# Network structure, given as directed (parent, child) edges.\n",
    "# Nodes: D = exam difficulty, I = student talent, G = grade,\n",
    "#        L = recommendation letter, S = SAT score.\n",
    "letter_model = BayesianModel([\n",
    "    ('D', 'G'),\n",
    "    ('I', 'G'),\n",
    "    ('G', 'L'),\n",
    "    ('I', 'S'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conditional probability distributions (CPDs), one per node of the network.\n",
    "# In each table, row k holds the probabilities of state k of `variable`, and\n",
    "# there is one column per combination of evidence (parent) states.\n",
    "\n",
    "# Exam difficulty D: no parents, so the table is a single prior column.\n",
    "difficulty_cpd = TabularCPD(\n",
    "    variable='D',\n",
    "    variable_card=2,\n",
    "    values=[[0.6], [0.4]],\n",
    ")\n",
    "\n",
    "# Student talent I: no parents, prior only.\n",
    "intel_cpd = TabularCPD(\n",
    "    variable='I',\n",
    "    variable_card=2,\n",
    "    values=[[0.7], [0.3]],\n",
    ")\n",
    "\n",
    "# Grade G: three states, conditioned on talent I and difficulty D.\n",
    "# Columns are ordered over (I, D) with D varying fastest:\n",
    "# (I=0, D=0), (I=0, D=1), (I=1, D=0), (I=1, D=1).\n",
    "grade_cpd = TabularCPD(\n",
    "    variable='G',          # node name\n",
    "    variable_card=3,       # number of states of the node\n",
    "    evidence=['I', 'D'],   # parent nodes\n",
    "    evidence_card=[2, 2],  # number of states of each parent\n",
    "    values=[               # probability table of the node\n",
    "        [0.3, 0.05, 0.9, 0.5],\n",
    "        [0.4, 0.25, 0.08, 0.3],\n",
    "        [0.3, 0.7, 0.02, 0.2],\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Recommendation-letter quality L: conditioned on the grade G (one column per grade).\n",
    "letter_cpd = TabularCPD(\n",
    "    variable='L',\n",
    "    variable_card=2,\n",
    "    evidence=['G'],\n",
    "    evidence_card=[3],\n",
    "    values=[\n",
    "        [0.1, 0.4, 0.99],\n",
    "        [0.9, 0.6, 0.01],\n",
    "    ],\n",
    ")\n",
    "\n",
    "# SAT score S: conditioned on talent I (one column per talent state).\n",
    "sat_cpd = TabularCPD(\n",
    "    variable='S',\n",
    "    variable_card=2,\n",
    "    evidence=['I'],\n",
    "    evidence_card=[2],\n",
    "    values=[\n",
    "        [0.95, 0.2],\n",
    "        [0.05, 0.8],\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Replacing existing CPD for G\n",
      "WARNING:root:Replacing existing CPD for D\n",
      "WARNING:root:Replacing existing CPD for I\n",
      "WARNING:root:Replacing existing CPD for L\n",
      "WARNING:root:Replacing existing CPD for S\n",
      "Finding Elimination Order: : 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 668.95it/s]\n",
      "Eliminating: L: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 285.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+----------+\n",
      "| G    |   phi(G) |\n",
      "+======+==========+\n",
      "| G(0) |   0.9000 |\n",
      "+------+----------+\n",
      "| G(1) |   0.0800 |\n",
      "+------+----------+\n",
      "| G(2) |   0.0200 |\n",
      "+------+----------+\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Attach every node's CPD to the model, completing the Bayesian network.\n",
    "# Re-running this cell on a live kernel logs 'Replacing existing CPD'\n",
    "# warnings because the CPDs are already attached; the result is unchanged.\n",
    "letter_model.add_cpds(\n",
    "    grade_cpd,\n",
    "    difficulty_cpd,\n",
    "    intel_cpd,\n",
    "    letter_cpd,\n",
    "    sat_cpd,\n",
    ")\n",
    "# Validate the network before querying it: check_model() returns True when\n",
    "# the CPDs are consistent with the graph and their columns sum to 1, and\n",
    "# raises otherwise, so a typo in a table fails here rather than silently\n",
    "# producing a wrong posterior.\n",
    "assert letter_model.check_model()\n",
    "\n",
    "# Import pgmpy's variable-elimination inference engine.\n",
    "from pgmpy.inference import VariableElimination\n",
    "\n",
    "# Build the inference object over the network.\n",
    "letter_infer = VariableElimination(letter_model)\n",
    "\n",
    "# Posterior distribution of the grade G for a talented student (I = 1)\n",
    "# taking an exam that is not difficult (D = 0).\n",
    "prob_G = letter_infer.query(\n",
    "    variables=['G'],\n",
    "    evidence={'I': 1, 'D': 0},\n",
    ")\n",
    "print(prob_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
